{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "torch.manual_seed(2618)\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '7'\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms, models\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
    "from datasets import Dataset\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import io\n",
    "import hashlib\n",
    "\n",
    "\n",
    "data_path = '/data02/users/lz/code/UICoder/datasets/c4-wash/c4-format-marked/merged'\n",
    "save_path = '/data02/users/lz/code/UICoder/checkpoints/classifier/c4_test.pth'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import os\n",
    "# from PIL import Image\n",
    "# from tqdm import tqdm\n",
    "# train_folder = '/data02/train_data/train'\n",
    "# dir = '/data02/train_data/data_'\n",
    "# class_ = 'others'\n",
    "# dir = dir + class_\n",
    "\n",
    "# os.makedirs(os.path.join(train_folder,class_),exist_ok=True)\n",
    "# for idx,filename in enumerate(tqdm(os.listdir(dir))):\n",
    "#     path = os.path.join(dir,filename)\n",
    "#     image = Image.open(path)\n",
    "#     image.save(os.path.join(train_folder,class_,filename))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def image_to_md5(image):\n",
    "#     image_bytes = io.BytesIO()\n",
    "#     image.save(image_bytes, format='PNG')\n",
    "#     image_data = image_bytes.getvalue()\n",
    "#     md5_hash = hashlib.md5(image_data)\n",
    "#     md5_hex = md5_hash.hexdigest()\n",
    "#     return str(md5_hex)\n",
    "\n",
    "# os.makedirs(imageFolder,exist_ok=True)\n",
    "# imageFolder = os.path.join(data_path,'imageFolder-balanced')\n",
    "# for i in range(0,6):\n",
    "#     os.makedirs(os.path.join(imageFolder,f'{i}'),exist_ok=True)\n",
    "\n",
    "# ds = Dataset.load_from_disk(data_path)\n",
    "# for item in tqdm(ds):\n",
    "#     item['image'].save(os.path.join(imageFolder,f\"{item['score_avr']}/{image_to_md5(item['image'])}.png\"))\n",
    "\n",
    "# for score in range(0,5):\n",
    "#     for idx,filename in enumerate(os.listdir(os.path.join(imageFolder, f'{score}'))):\n",
    "#         if idx >= count_min:\n",
    "#             file_path = os.path.join(os.path.join(imageFolder, f'{score}', filename))\n",
    "#             os.remove(file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3542, 3840, 1189, 634, 120, 5]\n"
     ]
    }
   ],
   "source": [
    "counts = []\n",
    "for score in range(0,6):\n",
    "    count = len(os.listdir(os.path.join(os.path.join(data_path,'imageFolder'),f'{score}')))\n",
    "    counts.append(count)\n",
    "print(counts)\n",
    "count_min = min(counts[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:26<00:00, 43.32it/s]"
     ]
    }
   ],
   "source": [
    "# 定义数据转换\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                         std=[0.229, 0.224, 0.225])\n",
    "])\n",
    "inverse_transform = transforms.Compose([\n",
    "    transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],\n",
    "                         std=[1/0.229, 1/0.224, 1/0.225]),\n",
    "    transforms.ToPILImage()\n",
    "])\n",
    "\n",
    "# 加载完整的数据集\n",
    "train_dataset = datasets.ImageFolder(os.path.join(data_path,'imageFolder-balanced'), transform=transform)\n",
    "\n",
    "test_dataset = datasets.ImageFolder(os.path.join(data_path,'imageFolder'), transform=transform)\n",
    "val_dataset, test_dataset, _ = torch.utils.data.random_split(test_dataset, [500, 500, len(test_dataset)-1000])\n",
    "\n",
    "# 创建数据加载器\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=False)\n",
    "val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5'}\n"
     ]
    }
   ],
   "source": [
    "class2idx = full_dataset.class_to_idx\n",
    "idx2class = {}\n",
    "for class_ in class2idx:\n",
    "    idx2class[class2idx[class_]] = class_ \n",
    "print(idx2class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train: 100%|██████████| 38/38 [00:13<00:00,  2.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10,Running Loss: 5.2777, End Loss:  16.035381\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|██████████| 32/32 [00:10<00:00,  3.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval Loss: 1.0366\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train: 100%|██████████| 38/38 [00:13<00:00,  2.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/10,Running Loss: 3.6500, End Loss:  12.305770\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|██████████| 32/32 [00:10<00:00,  2.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval Loss: 0.9192\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train: 100%|██████████| 38/38 [00:13<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/10,Running Loss: 2.7216, End Loss:  9.728524\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|██████████| 32/32 [00:10<00:00,  2.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval Loss: 1.1373\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train: 100%|██████████| 38/38 [00:13<00:00,  2.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/10,Running Loss: 2.1461, End Loss:  7.873539\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|██████████| 32/32 [00:10<00:00,  3.09it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval Loss: 1.6252\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train: 100%|██████████| 38/38 [00:13<00:00,  2.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/10,Running Loss: 1.7725, End Loss:  6.453003\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|██████████| 32/32 [00:10<00:00,  3.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval Loss: 2.1254\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train: 100%|██████████| 38/38 [00:13<00:00,  2.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6/10,Running Loss: 1.5022, End Loss:  5.230870\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|██████████| 32/32 [00:10<00:00,  3.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval Loss: 2.5926\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train: 100%|██████████| 38/38 [00:13<00:00,  2.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7/10,Running Loss: 1.2876, End Loss:  4.280138\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|██████████| 32/32 [00:10<00:00,  3.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval Loss: 3.3518\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train: 100%|██████████| 38/38 [00:13<00:00,  2.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8/10,Running Loss: 1.0093, End Loss:  3.383116\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|██████████| 32/32 [00:10<00:00,  3.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval Loss: 3.8055\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train: 100%|██████████| 38/38 [00:13<00:00,  2.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9/10,Running Loss: 0.7690, End Loss:  2.568637\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|██████████| 32/32 [00:10<00:00,  2.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval Loss: 4.9870\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train: 100%|██████████| 38/38 [00:13<00:00,  2.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10/10,Running Loss: 0.5443, End Loss:  1.811326\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval: 100%|██████████| 32/32 [00:10<00:00,  3.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval Loss: 6.2626\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# 定义模型\n",
    "model = models.resnet50(pretrained=True)\n",
    "num_ftrs = model.fc.in_features\n",
    "model.fc = nn.Linear(num_ftrs, 1)  # 输出为一个评分，因此输出维度为1\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = model.to(device)\n",
    "\n",
    "num_epochs = 10\n",
    "\n",
    "# 定义损失函数和优化器\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.00001)\n",
    "scheduler = CosineAnnealingLR(optimizer, T_max=38*num_epochs)\n",
    "\n",
    "# 训练模型\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    for inputs, labels in tqdm(train_loader,desc=\"Train\"):\n",
    "        inputs, labels = inputs.to(device), labels.float().to(device)  # 转换为浮点型\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs.squeeze(), labels)  # 计算损失\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        running_loss += loss.item()\n",
    "        \n",
    "    epoch_loss = running_loss / len(train_loader)\n",
    "    print(f\"Epoch {epoch+1}/{num_epochs},Running Loss: {epoch_loss:.4f}, End Loss: {loss.item(): 4f}\")\n",
    "\n",
    "    # 在验证集上评估模型\n",
    "    model.eval()\n",
    "    val_loss = 0.0\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in tqdm(val_loader,desc=\"Eval\"):\n",
    "            inputs, labels = inputs.to(device), labels.float().to(device)\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs.squeeze(), labels)\n",
    "            val_loss += loss.item()\n",
    "    \n",
    "    val_loss /= len(val_loader)\n",
    "    print(f\"Eval Loss: {val_loss:.4f}\\n\")\n",
    "\n",
    "    scheduler.step()\n",
    "\n",
    "# 保存模型\n",
    "torch.save(model.state_dict(), save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [01:40<00:00,  4.95it/s]\n",
      "/tmp/ipykernel_3953953/2819384816.py:17: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n",
      "  prediction = int(np.round(outputs.cpu().numpy()[0]))\n",
      "100%|█████████▉| 498/500 [00:12<00:00, 46.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Acc: 0.38\n",
      "G2 Acc: 0.0\n",
      "L1 Acc: 0.47\n",
      "G2L1 Acc: 0.38\n"
     ]
    }
   ],
   "source": [
    "# 加载模型\n",
    "model = models.resnet50()\n",
    "num_ftrs = model.fc.in_features\n",
    "model.fc = nn.Linear(num_ftrs, 1)\n",
    "\n",
    "model.to(device)\n",
    "model.load_state_dict(torch.load(save_path))\n",
    "model.eval()\n",
    "\n",
    "# 预测并保存结果\n",
    "predictions = []\n",
    "with torch.no_grad():\n",
    "    tbar = tqdm(total=len(test_dataset))\n",
    "    for inputs, label in test_dataset:\n",
    "        tbar.update(1)\n",
    "        outputs = model(torch.stack([inputs]).to(device))\n",
    "        prediction = int(np.round(outputs.cpu().numpy()[0]))\n",
    "        predictions.append([label,prediction])  # 假设模型输出是一个 numpy 数组\n",
    "        \n",
    "g2 = 0\n",
    "g2_right = 0\n",
    "l1 = 0\n",
    "l1_right = 0\n",
    "count = 0\n",
    "allowance = 0\n",
    "for item in predictions:\n",
    "    if item[1] < 0:\n",
    "        item[1] = 0\n",
    "    elif item[1] > 5:\n",
    "        item[1] = 5\n",
    "    if item[0] >= 2:\n",
    "        g2 += 1\n",
    "        if abs(item[0] - item[1]) <= allowance:\n",
    "            g2_right += 1\n",
    "    if item[0] <= 1:\n",
    "        l1 += 1\n",
    "        if abs(item[0] - item[1]) <= allowance:\n",
    "            l1_right += 1\n",
    "    if abs(item[0] - item[1]) <= allowance:\n",
    "        count += 1\n",
    "print(f'Acc: {round(count/len(predictions),2)}')\n",
    "print(f'G2 Acc: {round(g2_right/g2,2)}')\n",
    "print(f'L1 Acc: {round(l1_right/l1,2)}')\n",
    "print(f'G2L1 Acc: {round((g2_right+l1_right)/(g2+l1),2)}')\n",
    "\n",
    "\n",
    "# 将预测结果保存到文件\n",
    "with open('predictions.txt', 'w') as f:\n",
    "    for pred in predictions:\n",
    "        f.write(f\"{pred}\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "UICoder",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
